#!/usr/bin/env python3
"""A script to generate FileCheck statements for mlir unit tests.

This script is a utility to add FileCheck patterns to an mlir file.

NOTE: The input .mlir is expected to be the output from the parser, not a
stripped down variant.

Example usage:
$ generate-test-checks.py foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'

The script will heuristically generate CHECK/CHECK-LABEL commands for each line
within the file. By default this script will also try to insert string
substitution blocks for all SSA value names. If a --source file is specified,
the script will attempt to insert the generated CHECKs into the source file,
placing each chunk before the corresponding line matched by --source_delim_regex.

The script is designed to make adding checks to a test case fast; it is *not*
designed to be authoritative about what constitutes a good test!
"""

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
import os  # Used to advertise this file's name ("autogenerated_note").
import re
import sys
from collections import Counter

# Text of the autogenerated note written at the top of the output. main()
# builds the full note as ADVERT_BEGIN + "utils/" + <script file name> + "\n"
# + ADVERT_END. With --source, process_source_lines() receives the same note
# so a copy inserted by a previous run can be dropped. Editing this text
# therefore also changes what is recognised as a stale note.
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// This script is intended to make adding checks to a test case quick and easy.
// It is *not* authoritative about what constitutes a good test. After using the
// script, be sure to review and refine the generated checks. For example,
// CHECK lines should be minimized and named to reflect the test’s intent.
// For comprehensive guidelines, see:
//   * https://mlir.llvm.org/getting_started/TestingGuide/
"""


# Regex matching an SSA identifier without its leading '%': either a numeric
# id (`0`) or a named id (`arg0`). This is a bare alternation, so it must be
# wrapped in a group whenever it is combined with other pattern text.
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regex matching ` = dialect.op_name` (e.g. ` = vector.transfer_read`); group 1
# captures the op name without the dialect (`transfer_read`).
SSA_OP_NAME_RE = re.compile(r"\b(?:\s=\s[a-z_]+)[.]([a-z_]+)\b")

# A single entry of an op's result list: `%id`, optionally followed by a `:N`
# result-group count (as in `%0:2 = ...`). SSA_RE_STR is grouped so that the
# '%' applies to both of its alternatives; without the group only numeric
# results (`%0 =`) matched and named ones (`%res =`) were missed.
SSA_RESULT_STR = r"%(?:" + SSA_RE_STR + r")(?::[0-9]+)?"

# Regex matching the left-hand side of an assignment, e.g. `%0, %res =`.
SSA_RESULTS_STR = (
    r"\s*(" + SSA_RESULT_STR + r")(\s*,\s*(" + SSA_RESULT_STR + r"))*\s*="
)
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching attributes
ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the left-hand side of an attribute definition
ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)


# Class used to generate and manage string substitution blocks for SSA value
# names.
class VariableNamer:
    def __init__(self, variable_names):
        # Stack of name scopes. Each scope maps an MLIR SSA name (without the
        # leading '%') to its FileCheck variable name; index 0 is outermost.
        self.scopes = []
        # Counter for generic FileCheck names, e.g. VAL_#N
        self.name_counter = 0
        # Counters for FileCheck names derived from Op names, e.g.
        # TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a
        # dedicated counter for every Op type present in the input.
        self.op_name_counter = Counter()

        # Number of variable names to still generate in parent scope
        self.generate_in_parent_scope_left = 0

        # Names requested by the user (comma separated), consumed in the
        # order names are generated; an empty entry selects the default name.
        self.variable_names = [name.upper() for name in variable_names.split(',')]
        self.used_variable_names = set()

    # Generate the following 'n' variable names in the parent scope.
    def generate_in_parent_scope(self, n):
        self.generate_in_parent_scope_left = n

    # Generate a substitution name for the given SSA value name and record it
    # in the current scope (or in the parent scope while names requested via
    # generate_in_parent_scope are outstanding). Raises RuntimeError if the
    # chosen name has already been used since the last clear_names().
    def generate_name(self, source_variable_name, use_ssa_name, op_name=""):

        # Compute variable name
        variable_name = (
            self.variable_names.pop(0) if len(self.variable_names) > 0 else ""
        )
        if variable_name == "":
            # If `use_ssa_name` is set, use the MLIR SSA value name to generate
            # a FileCheck substitution string. As FileCheck requires these
            # strings to start with a letter, skip MLIR variables starting
            # with a digit (e.g. `%0`). `[:1]` (rather than `[0]`) also copes
            # with an empty name instead of raising IndexError.
            #
            # The next fallback option is to use the op name, if the
            # corresponding match succeeds.
            #
            # If neither worked, use a generic name: `VAL_#N`.
            if use_ssa_name and source_variable_name[:1].isalpha():
                # FileCheck variable names may only contain [a-zA-Z0-9_],
                # whereas SSA names may also contain '$', '.' and '-'
                # (see SSA_RE_STR), so map those characters to '_'.
                variable_name = re.sub(
                    r"[^A-Za-z0-9_]", "_", source_variable_name.upper()
                )
            elif op_name != "":
                variable_name = (
                    op_name.upper() + "_" + str(self.op_name_counter[op_name])
                )
                self.op_name_counter[op_name] += 1
            else:
                variable_name = "VAL_" + str(self.name_counter)
                self.name_counter += 1

        # Scope where variable name is saved
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        assert scope >= 0

        # Save variable
        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ': duplicate variable name')
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)

        return variable_name

    # Push a new variable name scope.
    def push_name_scope(self):
        self.scopes.append({})

    # Pop the last variable name scope.
    def pop_name_scope(self):
        self.scopes.pop()

    # Return the level of nesting (number of pushed scopes).
    def num_scopes(self):
        return len(self.scopes)

    # Reset the generic/op-name counters and the set of used names. The
    # scopes and the remaining user-requested names are left untouched.
    def clear_names(self):
        self.name_counter = 0
        self.used_variable_names = set()
        self.op_name_counter.clear()

class AttributeNamer:
    # Generates and remembers FileCheck substitution names for attribute
    # aliases such as `#map`. Every generated name starts with '$'.

    def __init__(self, attribute_names):
        self.name_counter = 0
        # User-requested names in definition order; an empty entry means
        # "fall back to ATTR_<N>".
        self.attribute_names = list(map(str.upper, attribute_names.split(',')))
        # Source attribute name -> generated substitution name.
        self.map = {}
        self.used_attribute_names = set()

    # Generate (and record) a substitution name for the given attribute name.
    # Raises RuntimeError when the resulting name was already handed out.
    def generate_name(self, source_attribute_name):
        requested = self.attribute_names.pop(0) if self.attribute_names else ''
        if not requested:
            requested = f"ATTR_{self.name_counter}"
            self.name_counter += 1

        attribute_name = f"${requested}"
        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ': duplicate attribute name')

        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    # Look up the substitution name previously generated for the given
    # attribute name; returns None if there is none.
    def get_name(self, source_attribute_name):
        return self.map.get(source_attribute_name, None)

# Count the SSA results defined by a line of the form
#   %0, %1, ... = ...
# Lines that do not start with such a result list yield 0.
def get_num_ssa_results(input_line):
    results = SSA_RESULTS_RE.match(input_line)
    if results is None:
        return 0
    return results.group(0).count('%')


# Process a line of input that has been split at each SSA identifier '%'.
def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
    """Rewrite the SSA values in a '%'-split line as FileCheck variables.

    `line_chunks` are the pieces that followed each '%' in the line (the
    caller handles the piece before the first '%'). For each chunk, the
    leading SSA identifier becomes either a use of an already captured
    variable (`%[[NAME]]`) or a new capture (`%[[NAME:regex]]`); the rest of
    the chunk is copied verbatim. A chunk that does not start with an SSA
    identifier (e.g. from `%%` or a trailing `%`) is emitted as a literal
    '%' followed by the chunk.

    `use_ssa_name` is forwarded to VariableNamer.generate_name. With
    `strict_name_re`, new captures use SSA_RE_STR instead of `.*`.

    Returns the rewritten text with trailing whitespace removed and a single
    trailing newline.
    """
    output_line = ""

    for chunk in line_chunks:
        ssa = SSA_RE.match(chunk)
        if ssa is None:
            # Not an SSA value: keep the '%' literally instead of creating a
            # capture for an empty name.
            output_line += "%" + chunk
            continue
        ssa_name = ssa.group(0)

        # Check if an existing variable exists for this name (scopes are
        # searched from the outermost one).
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        if variable is not None:
            # If one exists, then output the existing name.
            output_line += "%[[" + variable + "]]"
        else:
            # Otherwise, generate a new variable, preferably named after the
            # op defining it (only needed on this path).
            op_name_with_dialect = SSA_OP_NAME_RE.search(chunk)
            op_name = (
                op_name_with_dialect.group(1) if op_name_with_dialect is not None else ""
            )
            variable = variable_namer.generate_name(ssa_name, use_ssa_name, op_name)
            if strict_name_re:
                # Use stricter regexp for the variable name, if requested.
                # Greedy matching may cause issues with the generic '.*'
                # regexp when the checks are split across several
                # lines (e.g. for CHECK-SAME).
                output_line += "%[[" + variable + ":" + SSA_RE_STR + "]]"
            else:
                output_line += "%[[" + variable + ":.*]]"

        # Append the non named group.
        output_line += chunk[len(ssa_name) :]

    return output_line.rstrip() + "\n"


# Process the source file lines. The source file doesn't have to be .mlir.
def process_source_lines(source_lines, note, args):
    """Split the --source file into segments, dropping stale generated lines.

    Args:
      source_lines: lines of the source file, without trailing whitespace.
      note: the autogenerated note; a source line equal to one of its lines
        is treated as a leftover from a previous run and dropped.
      args: parsed arguments. `source_delim_regex` starts a new segment and
        `check_prefix` identifies previously generated check lines.

    Returns:
      A list of segments (lists of newline-terminated lines). Segment 0 holds
      everything before the first delimiter line; every later segment starts
      with a delimiter line. Blank lines are dropped.
    """
    source_split_re = re.compile(args.source_delim_regex)
    # Compare whole lines. `line in note` on the note *string* is a substring
    # test, which also discarded unrelated source lines such as `//`.
    note_lines = {note_line.rstrip() for note_line in note.splitlines()}
    # Match FileCheck directives for the prefix (`CHECK:`, `CHECK-LABEL:`,
    # `CHECK-COUNT-2:`, `CHECK{LITERAL}:`, ...) rather than any occurrence of
    # the prefix, so lines such as `// RUN: ... --check-prefix=CHECK` survive.
    check_directive_re = re.compile(
        r"(?<![\w-])"
        + re.escape(args.check_prefix)
        + r"(?:-[\w-]+)?(?:\{LITERAL\})?:"
    )

    source_segments = [[]]
    for line in source_lines:
        # Remove blank lines (as before) and lines of a previous note.
        if not line or line in note_lines:
            continue
        # Remove previous CHECK lines.
        if check_directive_re.search(line):
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + "\n")
    return source_segments


def process_attribute_definition(line, attribute_namer):
    """Turn an attribute alias definition into a capturing CHECK line.

    For a line such as `#map = ...`, a fresh substitution name is generated
    for `#map` and the line `// CHECK: #[[$NAME:.+]] =<rest>` (newline
    terminated) is returned, where references to known attributes in <rest>
    are rewritten as well. Returns None if `line` is not an attribute
    definition.
    """
    definition = ATTR_DEF_RE.match(line)
    if not definition:
        return None

    name = attribute_namer.generate_name(definition.group(1))
    # The right-hand side may itself reference earlier attribute aliases.
    rest = process_attribute_references(
        line[len(definition.group(0)) :], attribute_namer
    )
    return f"// CHECK: #[[{name}:.+]] ={rest}\n"

def process_attribute_references(line, attribute_namer):
    """Replace uses of already-named attribute aliases in `line`.

    Each attribute token that `attribute_namer` knows is rewritten as the
    FileCheck use `#[[$NAME]]`; every other piece of the line is kept as is.
    """
    pieces = []
    for component in ATTR_RE.split(line):
        match = ATTR_RE.match(component)
        name = attribute_namer.get_name(match.group(1)) if match else None
        if not name:
            pieces.append(component)
            continue
        pieces.append("#[[" + name + "]]" + component[len(match.group()) :])
    return "".join(pieces)

# Escape character sequences in a line of input that FileCheck would otherwise
# interpret specially.
def preprocess_line(line):
    # The replacements are applied one after another to the running result,
    # so their order is significant.
    escapes = (
        # `{{` opens a regex block in FileCheck.
        ("{{", "{{\\{\\{}}"),
        # `[[` opens a variable reference in FileCheck.
        ("[[", "{{\\[\\[}}"),
        # A `[` directly before an SSA value would sit next to the `[[` of
        # the variable that will replace the value.
        ("[%", "{{\\[}}%"),
    )
    escaped_line = line
    for needle, replacement in escapes:
        escaped_line = escaped_line.replace(needle, replacement)
    return escaped_line


def _parse_bool(value):
    """Convert a command-line string into a bool (used for --strict_name_re).

    `type=bool` does not work with argparse: `bool("False")` is True, so any
    non-empty value used to enable the option. Accept common spellings and
    reject anything else.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got %r" % value
    )


def _apply_check_prefix(check_line, check_prefix):
    """Make a line generated with the default `// CHECK:` use `check_prefix`.

    process_attribute_definition() always emits the default `CHECK` prefix;
    without this rewrite, attribute checks would ignore --check-prefix.
    """
    default_head = "// CHECK:"
    if check_line.startswith(default_head):
        return "// " + check_prefix + ":" + check_line[len(default_head) :]
    return check_line


def _build_arg_parser():
    """Build the command-line parser (option names/defaults are the CLI)."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--check-prefix", default="CHECK", help="Prefix to use from check file."
    )
    parser.add_argument(
        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
    )
    parser.add_argument(
        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument(
        "-i",
        "--inplace",
        action="store_true",
        default=False,
        help="Overwrite the --source file with the generated output.",
    )
    parser.add_argument(
        "--variable_names",
        type=str,
        default='',
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')")
    parser.add_argument(
        "--attribute_names",
        type=str,
        default='',
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')")
    parser.add_argument(
        "--strict_name_re",
        type=_parse_bool,
        nargs="?",
        const=True,
        default=False,
        help="Set to true to use stricter regex for CHECK-SAME directives. "
        "Use when Greedy matching causes issues with the generic '.*'",
    )
    return parser


def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate option combinations before touching any file.
    if args.inplace:
        if not args.source:
            parser.error("--inplace requires --source")
        if args.output is not None:
            parser.error("--inplace cannot be combined with --output")

    # Read the whole input up front.
    input_lines = [line.rstrip() for line in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        with open(args.source, "r") as source_file:
            source_lines = [line.rstrip() for line in source_file]
        source_segments = process_source_lines(
            source_lines, autogenerated_note, args
        )

    output_segments = [[]]

    # Namers
    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    # Store attribute definitions to emit at appropriate scope
    pending_attr_defs = []

    # The CHECK-SAME argument lines always derive names from the SSA names.
    use_ssa_names = True

    # Process lines
    for input_line in input_lines:
        if not input_line:
            continue

        # When using `--starts_from_scope=0` to capture module lines, the file
        # split needs to be skipped, otherwise a `CHECK: // -----` is inserted.
        if input_line.startswith("// -----"):
            continue

        if ATTR_DEF_RE.match(input_line):
            pending_attr_defs.append(input_line)
            continue

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        lstripped_input_line = input_line.lstrip()
        is_block = lstripped_input_line[0] == "^"
        if is_block:
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == "}":
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == "{":
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

            # Result SSA values must still be pushed to parent scope
            num_ssa_results = get_num_ssa_results(input_line)
            variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        if len(output_segments[-1]) == 0:
            variable_namer.clear_names()

        # Preprocess the input to remove any sequences that may be problematic with
        # FileCheck.
        input_line = preprocess_line(input_line)

        # Process uses of attributes in this line
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at the each SSA value name.
        ssa_split = input_line.split("%")

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad to align with the 'LABEL' statements.
            output_line += " " * len("-LABEL")

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)

        else:
            # Emit any pending attribute definitions at the start of this scope
            for attr in pending_attr_defs:
                attr_line = process_attribute_definition(attr, attribute_namer)
                if attr_line:
                    output_segments[-1].append(
                        _apply_check_prefix(attr_line, args.check_prefix)
                    )
            pending_attr_defs.clear()

            # Output the first line chunk that does not contain an SSA name for the
            # label.
            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"

            # Pad to align with the original position in the line (i.e. where
            # the label ends), unless the label is more than 20 chars long, in
            # which case pad with 4 spaces (this is to avoid deep indentation).
            label_length = len(ssa_split[0])
            pad_depth = label_length if label_length < 21 else 4

            # Process the rest of the input line on separate check lines.
            for argument in ssa_split[1:]:
                output_line += "// " + args.check_prefix + "-SAME:  "
                output_line += " " * pad_depth

                # Process the rest of the line. Use the original SSA name to generate the LIT
                # variable names.
                output_line += process_line(
                    [argument], variable_namer, use_ssa_names, args.strict_name_re
                )

        # Append the output line.
        output_segments[-1].append(output_line)

    # Validate before opening the output: with --inplace, opening the source
    # for writing truncates it, so any failure after that point would lose it.
    if source_segments and len(output_segments) != len(source_segments):
        raise RuntimeError(
            "generated %d CHECK segments but the source file has %d segments; "
            "adjust --source_delim_regex or --starts_from_scope"
            % (len(output_segments), len(source_segments))
        )

    if args.inplace:
        output = open(args.source, "w")
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output.write(autogenerated_note + "\n")

        # Write the output.
        if source_segments:
            for check_segment, source_segment in zip(output_segments, source_segments):
                output.writelines(check_segment)
                output.writelines(source_segment)
        else:
            for segment in output_segments:
                output.write("\n")
                output.writelines(segment)
            output.write("\n")
    finally:
        if output is not sys.stdout:
            output.close()


if __name__ == "__main__":
    main()
